from keras.datasets import mnist
import matplotlib.pyplot as plt

# Fetch MNIST as two (images, labels) splits: training and test.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Dataset dimensions. Per the names chosen here, axis 0 counts images and
# axes 1 and 2 are the per-image height and width.
num_train_images = X_train.shape[0]
num_test_images = X_test.shape[0]
image_height, image_width = X_train.shape[1], X_train.shape[2]

# Report the training-set dimensions on stdout.
print(f"Shape: {X_train.shape}")
print(f"Training images: {num_train_images}")
print(f"Image height: {image_height}")
print(f"Image width: {image_width}")

def plot(X, num_images=12):
    """Show the first ``num_images`` images of ``X`` side by side in one row.

    Parameters
    ----------
    X : sized, iterable collection of images
        Each element is passed to ``Axes.imshow`` and drawn in grayscale.
    num_images : int, optional
        Maximum number of images to display (default 12, the count this
        function originally hard-coded). If ``X`` holds fewer images, all of
        them are shown instead of failing partway through.

    Raises
    ------
    ValueError
        If there is nothing to plot (``X`` is empty or ``num_images < 1``).
    """
    count = min(num_images, len(X))
    if count < 1:
        raise ValueError("plot() needs at least one image to display")

    # squeeze=False keeps `axs` a 2-D array even when count == 1; otherwise
    # matplotlib returns a bare Axes object that cannot be indexed.
    _, axs = plt.subplots(1, count, figsize=(17, 6), squeeze=False)

    # zip stops after `count` pairs because the first row has exactly `count` axes.
    for ax, image in zip(axs[0], X):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.show()
    
# Take the first 12 test images (a slice along axis 0) and display them.
X = X_test[:12]
plot(X)
